"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


This script creates the flow stability clustering of several community events and
produces Fig. S2.


"""

import sys
import os
PACKAGE_PARENT = '..'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

import pickle
import numpy as np


from itertools import combinations, product
from collections import Counter
from scipy.linalg import expm, block_diag

from scipy import integrate

def Round_To_n(x, n):
    """Round ``x`` relative to its leading significant digit.

    The number of decimals passed to ``round`` is ``n`` minus the decimal
    exponent of ``x``, so ``n`` digits are kept after the leading digit
    (e.g. ``Round_To_n(123.456, 3)`` gives ``123.5``).

    Parameters
    ----------
    x : float
        Value to round. Zero is returned unchanged.
    n : int
        Number of digits kept after the leading significant digit.

    Returns
    -------
    float
        The rounded value, with the same sign as ``x``.
    """
    if x == 0:
        # log10(0) is -inf and cannot be turned into an integer exponent.
        return x
    # Use the magnitude only: multiplying the logarithm by the sign of x
    # would flip the exponent for negative inputs.
    exponent = int(np.floor(np.log10(abs(x))))
    return round(x, n - exponent)

import matplotlib.pyplot as plt
import matplotlib as mpl
plt.style.use('alex_paper')

from TemporalStability import norm_var_information, Clustering



# Line width applied to every matplotlib line drawn after this point.
mpl.rcParams['lines.linewidth'] = 1.7

# Output folder for figures (relative to the current working directory).
figdir = '../figures/community_events'

# Folder where the pickled results of the lambda scan are written and read.
datadir = '../paper_data/synthtempnet/community_events/'

# Create both folders if they do not exist yet.
os.makedirs(datadir, exist_ok=True)
os.makedirs(figdir, exist_ok=True)

net_file = os.path.join(datadir,'net.pickle')

# NOTE(review): this unconditional raise stops the script here when the file is
# executed top to bottom. Presumably it is a guard so that the `#%%` cells
# below are run one at a time from an IDE -- confirm before removing.
raise Exception
#%% functions
def get_A(t, t_end=1):
    """Template for the adjacency-matrix callbacks passed to ``get_L``.

    The concrete ``get_A_*`` functions below share this signature; this
    placeholder itself always raises ``NotImplementedError``.
    """
    raise NotImplementedError

def get_L(get_A, t, t_end=1, **kwargs):
    """Return the random-walk Laplacian ``I - D^-1 A`` at time ``t``.

    Parameters
    ----------
    get_A : callable
        Called as ``get_A(t, t_end=t_end, **kwargs)``; must return a square
        adjacency matrix.
    t : float
        Time at which the adjacency matrix is evaluated.
    t_end : float, optional
        Forwarded to ``get_A``.
    **kwargs
        Extra keyword arguments forwarded to ``get_A``.

    Returns
    -------
    numpy.ndarray
        ``np.eye(n) - np.diag(1/degs) @ A``, where nodes with zero degree
        are given a unit self-loop first, so their Laplacian row is zero.
    """
    # Work on a private float copy so that adding self-loops below never
    # modifies the array owned by the caller / by ``get_A``.
    A = np.array(get_A(t, t_end=t_end, **kwargs), dtype=float)
    degs = A.sum(axis=1)

    isolated = np.flatnonzero(degs == 0)
    if isolated.size > 0:
        # Avoid a division by zero: an isolated node gets degree 1 through
        # a self-loop of weight 1.
        degs[isolated] = 1
        A[isolated, isolated] += 1

    return np.eye(degs.size) - np.diag(1/degs) @ A


def integrate_matrix(A,a,b,shape,num_points=250):
    """Integrate a symmetric matrix-valued function with Simpson's rule.

    Parameters
    ----------
    A : callable
        ``A(t)`` returns a 2-D array; it is assumed to be symmetric, only
        its upper triangle is used.
    a, b : float
        Integration bounds.
    shape : tuple of int
        ``(n_rows, n_cols)`` of the result.
    num_points : int, optional
        Number of equally spaced sample times between ``a`` and ``b``.

    Returns
    -------
    numpy.ndarray
        Symmetric matrix whose upper triangle holds the entry-wise
        integrals of ``A`` over ``[a, b]``.
    """
    ts = np.linspace(a, b, num_points)
    n_rows, n_cols = shape

    # Evaluate A once per sample time (it can be expensive) instead of once
    # per matrix entry and sample time.
    samples = np.array([np.asarray(A(t))[:n_rows, :n_cols] for t in ts])

    # ``x`` is passed by keyword: recent SciPy versions no longer accept it
    # positionally.
    integral = integrate.simpson(samples, x=ts, axis=0)

    # Keep the upper triangle and mirror it, as the input is assumed symmetric.
    upper = np.triu(integral)
    return upper + upper.T - np.diag(upper.diagonal())



def compute_I(T_fct,size, t_end=1, l=1, p0=None, **kwargs):
    """Integrate ``P0 T(t) diag(1/p(t)) T(t)^T P0 - p0 p0^T`` over ``[0, t_end]``.

    Parameters
    ----------
    T_fct : callable
        Called as ``T_fct(t, t_end=t_end, l=l, **kwargs)``; returns the
        ``size x size`` transition matrix at time ``t``.
    size : int
        Number of nodes.
    t_end : float, optional
        Upper integration bound, also forwarded to ``T_fct``.
    l : float, optional
        Forwarded to ``T_fct``.
    p0 : numpy.ndarray, optional
        Initial distribution; uniform over ``size`` nodes when omitted.
    **kwargs
        Extra keyword arguments forwarded to ``T_fct``.

    Returns
    -------
    numpy.ndarray
        The ``size x size`` matrix returned by ``integrate_matrix``.
    """
    if p0 is None:
        p0 = np.ones(size)/size

    start_weights = np.diag(p0)
    independent_part = np.outer(p0, p0)

    def integrand(t):
        transition = T_fct(t, t_end=t_end, l=l, **kwargs)

        # Distribution reached at time t when starting from p0.
        p_t = p0 @ transition

        # Entries that are exactly zero are replaced by 1 so the inverse
        # square root below stays finite.
        p_t_safe = np.where(p_t == 0, 1, p_t)

        half = start_weights @ transition @ np.diag(np.sqrt(1/p_t_safe))
        return half @ half.T - independent_part

    return integrate_matrix(integrand, 0, t_end, (size, size))
    


        

#%% birth/death


def get_A_death(t, t_end=1, num_nodes=5):
    """Adjacency matrix of a community that exists, then disappears.

    Parameters
    ----------
    t : float
        Time in ``[0, t_end]``.
    t_end : float, optional
        End of the time window.
    num_nodes : int, optional
        Number of nodes.

    Returns
    -------
    numpy.ndarray
        All-to-all matrix without self-loops for ``0 <= t <= t_end/2``,
        all-zero matrix for ``t_end/2 < t <= t_end``.

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside ``[0, t_end]``.
    """
    if t>=0 and t<=t_end/2:
        A_birth = np.ones((num_nodes,num_nodes))
        np.fill_diagonal(A_birth,0)
        return A_birth

    if t>t_end/2 and t<=t_end:
        return np.zeros((num_nodes,num_nodes))

    # Raise (rather than return) the exception so that callers never
    # receive an exception object where a matrix is expected.
    raise NotImplementedError(f"t={t} is outside [0, {t_end}]")
    


def get_T_forw_death(t, t_end=1, l=1,  num_nodes=5):
    """Forward transition matrix from time 0 to ``t`` for the birth/death event.

    The Laplacian is piecewise constant (it changes at ``t_end/2``), so the
    matrix is ``expm(-l*t*L)`` in the first half and, in the second half,
    the matrix at ``t_end/2`` multiplied by the exponential over the
    remaining time.

    Parameters
    ----------
    t : float
        Time in ``[0, t_end]``.
    t_end : float, optional
        End of the time window.
    l : float, optional
        Rate multiplying the Laplacian.
    num_nodes : int, optional
        Number of nodes.

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside ``[0, t_end]``.
    """
    L_f = get_L(get_A_death, t, t_end=t_end,  num_nodes=num_nodes)

    if t>=0 and t<=t_end/2:
        return expm(-l*t*L_f)

    if t>t_end/2 and t<=t_end:
        T_half = get_T_forw_death(t_end/2,t_end=t_end,
                                  l=l, num_nodes=num_nodes)
        return T_half@expm(-l*(t-t_end/2)*L_f)

    # Raise instead of returning the exception instance.
    raise NotImplementedError(f"t={t} is outside [0, {t_end}]")
    
def get_T_back_death(t, t_end=1, l=1,  num_nodes=5):
    """Backward transition matrix for the birth/death event.

    Same construction as the forward version, but the Laplacian is taken
    at the reversed time ``t_end - t`` (plus a small offset, see below).

    Parameters
    ----------
    t : float
        Backward time in ``[0, t_end]``.
    t_end : float, optional
        End of the time window.
    l : float, optional
        Rate multiplying the Laplacian.
    num_nodes : int, optional
        Number of nodes.

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside ``[0, t_end]``.
    """
    # The interval boundaries of get_A_death are closed on the right, which
    # suits the forward direction; shifting the reversed time by a small
    # epsilon selects the matching interval when going backward.
    if t>t_end/1000:
        epsilon = t_end/1000
    else:
        epsilon = 0

    L_b = get_L(get_A_death, t_end-t+epsilon, t_end=t_end,  num_nodes=num_nodes)

    if t>=0 and t<=t_end/2:
        return expm(-l*t*L_b)

    if t>t_end/2 and t<=t_end:
        T_half = get_T_back_death(t_end/2,t_end=t_end,
                                  l=l, num_nodes=num_nodes)
        return T_half@expm(-l*(t-t_end/2)*L_b)

    # Raise instead of returning the exception instance.
    raise NotImplementedError(f"t={t} is outside [0, {t_end}]")
    
    
#%% growth contraction

def get_A_growth(t,t_end=1, n_group1 = 3, n_group2 = 3):
    """Adjacency matrix of a community that grows at ``t_end/2``.

    First half (``0 <= t <= t_end/2``): the first ``n_group1`` nodes only
    carry a self-loop, the last ``n_group2`` nodes form a clique.
    Second half (``t_end/2 < t <= t_end``): all nodes form one clique.

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside ``[0, t_end]``.
    """
    total = n_group1 + n_group2

    def clique(n):
        # All-to-all block without self-loops.
        block = np.ones((n, n))
        np.fill_diagonal(block, 0)
        return block

    before = np.zeros((total, total))
    before[:n_group1, :n_group1] = np.eye(n_group1)
    before[n_group1:, n_group1:] = clique(n_group2)

    after = clique(total)

    in_first_half = t >= 0 and t <= t_end/2
    in_second_half = t > t_end/2 and t <= t_end

    if in_first_half:
        return before
    if in_second_half:
        return after
    raise NotImplementedError()
        

def get_T_forw_growth(t, t_end=1, l=1, n_group1 = 3, n_group2 = 3):
    """Forward transition matrix from time 0 to ``t`` for the growth event.

    ``expm(-l*t*L)`` in the first half of the window; in the second half,
    the matrix at ``t_end/2`` multiplied by the exponential over the
    remaining time with the second-half Laplacian.

    Parameters
    ----------
    t : float
        Time in ``[0, t_end]``.
    t_end : float, optional
        End of the time window.
    l : float, optional
        Rate multiplying the Laplacian.
    n_group1, n_group2 : int, optional
        Group sizes forwarded to ``get_A_growth``.

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside ``[0, t_end]``.
    """
    L_f = get_L(get_A_growth, t, t_end=t_end, n_group1 = n_group1, n_group2 = n_group2)

    if t>=0 and t<=t_end/2:
        return expm(-l*t*L_f)

    if t>t_end/2 and t<=t_end:
        T_half = get_T_forw_growth(t_end/2,t_end=t_end,l=l,
                                   n_group1=n_group1, n_group2=n_group2)
        return T_half@expm(-l*(t-t_end/2)*L_f)

    # Raise instead of returning the exception instance.
    raise NotImplementedError(f"t={t} is outside [0, {t_end}]")
    
def get_T_back_growth(t, t_end=1, l=1, n_group1 = 3, n_group2 = 3):
    """Backward transition matrix for the growth event.

    Same construction as the forward version, with the Laplacian taken at
    the reversed time ``t_end - t`` (plus a small offset, see below).

    Parameters
    ----------
    t : float
        Backward time in ``[0, t_end]``.
    t_end : float, optional
        End of the time window.
    l : float, optional
        Rate multiplying the Laplacian.
    n_group1, n_group2 : int, optional
        Group sizes forwarded to ``get_A_growth``.

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside ``[0, t_end]``.
    """
    # Interval boundaries of get_A_growth are closed on the right (forward
    # convention); the epsilon shift picks the matching interval backward.
    if t>t_end/1000:
        epsilon = t_end/1000
    else:
        epsilon = 0

    L_b = get_L(get_A_growth, t_end-t+epsilon, t_end=t_end, n_group1 = n_group1, n_group2 = n_group2)

    if t>=0 and t<=t_end/2:
        return expm(-l*t*L_b)

    if t>t_end/2 and t<=t_end:
        T_half = get_T_back_growth(t_end/2,t_end=t_end,l=l,
                                   n_group1=n_group1, n_group2=n_group2)
        return T_half@expm(-l*(t-t_end/2)*L_b)

    # Raise instead of returning the exception instance.
    raise NotImplementedError(f"t={t} is outside [0, {t_end}]")



#%% merge/split
    
def get_A_merge(t, t_end=1, n_group1=3, n_group2=3):
    """Adjacency matrix of two communities that merge at ``t_end/2``.

    First half (``0 <= t <= t_end/2``): two disconnected cliques of sizes
    ``n_group1`` and ``n_group2``.
    Second half (``t_end/2 < t <= t_end``): one clique over all nodes.

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside ``[0, t_end]``.
    """
    total = n_group1 + n_group2

    def clique(n):
        # All-to-all block without self-loops.
        block = np.ones((n, n))
        np.fill_diagonal(block, 0)
        return block

    separate = np.zeros((total, total))
    separate[:n_group1, :n_group1] = clique(n_group1)
    separate[n_group1:, n_group1:] = clique(n_group2)

    together = clique(total)

    if t >= 0 and t <= t_end/2:
        return separate
    if t > t_end/2 and t <= t_end:
        return together
    raise NotImplementedError()

def get_T_forw_merge(t, t_end=1, l=1, n_group1 = 3, n_group2 = 3):
    """Forward transition matrix from time 0 to ``t`` for the merge event.

    ``expm(-l*t*L)`` in the first half of the window; in the second half,
    the matrix at ``t_end/2`` multiplied by the exponential over the
    remaining time with the second-half Laplacian.

    Parameters
    ----------
    t : float
        Time in ``[0, t_end]``.
    t_end : float, optional
        End of the time window.
    l : float, optional
        Rate multiplying the Laplacian.
    n_group1, n_group2 : int, optional
        Group sizes forwarded to ``get_A_merge``.

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside ``[0, t_end]``.
    """
    L_f = get_L(get_A_merge, t, t_end=t_end, n_group1 = n_group1, n_group2 = n_group2)

    if t>=0 and t<=t_end/2:
        return expm(-l*t*L_f)

    if t>t_end/2 and t<=t_end:
        T_half = get_T_forw_merge(t_end/2,t_end=t_end,l=l,
                                  n_group1=n_group1, n_group2=n_group2)
        return T_half@expm(-l*(t-t_end/2)*L_f)

    # Raise instead of returning the exception instance.
    raise NotImplementedError(f"t={t} is outside [0, {t_end}]")
    
def get_T_back_merge(t, t_end=1, l=1, n_group1 = 3, n_group2 = 3):
    """Backward transition matrix for the merge event.

    Same construction as the forward version, with the Laplacian taken at
    the reversed time ``t_end - t`` (plus a small offset, see below).

    Parameters
    ----------
    t : float
        Backward time in ``[0, t_end]``.
    t_end : float, optional
        End of the time window.
    l : float, optional
        Rate multiplying the Laplacian.
    n_group1, n_group2 : int, optional
        Group sizes forwarded to ``get_A_merge``.

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside ``[0, t_end]``.
    """
    # Interval boundaries of get_A_merge are closed on the right (forward
    # convention); the epsilon shift picks the matching interval backward.
    if t>t_end/1000:
        epsilon = t_end/1000
    else:
        epsilon = 0

    L_b = get_L(get_A_merge, t_end-t+epsilon, t_end=t_end, n_group1 = n_group1, n_group2 = n_group2)

    if t>=0 and t<=t_end/2:
        return expm(-l*t*L_b)

    if t>t_end/2 and t<=t_end:
        T_half = get_T_back_merge(t_end/2,t_end=t_end,l=l,
                                  n_group1=n_group1, n_group2=n_group2)
        return T_half@expm(-l*(t-t_end/2)*L_b)

    # Raise instead of returning the exception instance.
    raise NotImplementedError(f"t={t} is outside [0, {t_end}]")



#%% continue
num_nodes = 5

def get_A_continue(t, t_end=1, num_nodes=6):
    """Adjacency matrix of a community that never changes.

    ``t`` and ``t_end`` are accepted for signature compatibility with the
    other ``get_A_*`` functions but are not used: the result is always the
    all-to-all matrix on ``num_nodes`` nodes without self-loops.
    """
    return np.ones((num_nodes, num_nodes)) - np.eye(num_nodes)
    

def get_T_forw_continue(t,t_end=1, l=1, num_nodes=6):
    """Forward transition matrix ``expm(-l*t*L)`` for the static community."""
    laplacian = get_L(get_A_continue, t, t_end=t_end, num_nodes=num_nodes)
    rate = -l*t
    return expm(rate*laplacian)

def get_T_back_continue(t,t_end=1, l=1, num_nodes=6):
    """Backward transition matrix for the static community.

    The network does not depend on time, so this is the same expression
    as the forward matrix: ``expm(-l*t*L)``.
    """
    laplacian = get_L(get_A_continue, t, t_end=t_end, num_nodes=num_nodes)
    rate = -l*t
    return expm(rate*laplacian)
    
#%% resurgence
    

def get_A_resurg(t, t_end=1, num_nodes=5):
    """Adjacency matrix of a community that vanishes and then reappears.

    Parameters
    ----------
    t : float
        Time in ``[0, t_end]``.
    t_end : float, optional
        End of the time window.
    num_nodes : int, optional
        Number of nodes.

    Returns
    -------
    numpy.ndarray
        All-to-all matrix without self-loops during the first and last
        third of the window, all-zero matrix during the middle third.

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside the window.
    """
    A_all = np.ones((num_nodes,num_nodes))
    np.fill_diagonal(A_all,0)

    if t>= 0 and t<=t_end/3:
        return A_all
    if t> t_end/3 and t <= 2*t_end/3:
        return np.zeros_like(A_all)
    if t> 2*t_end/3 and t <= 3*t_end/3:
        return A_all

    # Fail loudly instead of silently returning None, like the other
    # get_A_* functions of this file do for out-of-range times.
    raise NotImplementedError(f"t={t} is outside [0, {t_end}]")
    

def get_T_forw_resurg(t,t_end=1,l=1,num_nodes=6):
    """Forward transition matrix from time 0 to ``t`` for the resurgence event.

    The Laplacian is piecewise constant on the three thirds of the window.
    Within a third, the result is the matrix at the start of that third
    (obtained recursively) multiplied by the exponential over the time
    elapsed since then.

    Parameters
    ----------
    t : float
        Time in ``[0, t_end]``.
    t_end : float, optional
        End of the time window.
    l : float, optional
        Rate multiplying the Laplacian.
    num_nodes : int, optional
        Number of nodes.

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside the window.
    """
    L = get_L(get_A_resurg, t, t_end=t_end, num_nodes=num_nodes)

    if t>=0 and t<=t_end/3:
        return expm(-l*t*L)

    if t>t_end/3 and t<=2*t_end/3:
        T_start = get_T_forw_resurg(t_end/3,t_end=t_end,
                                    l=l,num_nodes=num_nodes)
        return T_start@expm(-l*(t-t_end/3)*L)

    if t>2*t_end/3 and t<=3*t_end/3:
        T_start = get_T_forw_resurg(2*t_end/3,t_end=t_end,
                                    l=l,num_nodes=num_nodes)
        return T_start@expm(-l*(t-2*t_end/3)*L)

    # Raise instead of returning the exception instance.
    raise NotImplementedError(f"t={t} is outside [0, {t_end}]")
    
def get_T_back_resurg(t,t_end=1,l=1,num_nodes=6):
    """Backward transition matrix for the resurgence event.

    The event is treated as time-symmetric, so the forward matrix is
    returned unchanged.
    """
    options = dict(t_end=t_end, l=l, num_nodes=num_nodes)
    return get_T_forw_resurg(t, **options)
  
#%% theseus's boat
    
def get_A_theseus(t, t_end=1, boat_size=3):
    """Adjacency matrix of the "Theseus's boat" event.

    A clique of ``boat_size`` nodes slides, one node at a time, along
    ``boat_size + 4`` nodes; in the last interval both the final and the
    initial clique are present.

    Parameters
    ----------
    t : float
        Time in ``[0, t_end]``.
    t_end : float, optional
        End of the time window. Defaults to 1, like the other ``get_A_*``
        functions of this file.
    boat_size : int, optional
        Number of nodes of the moving clique.

    Returns
    -------
    numpy.ndarray
        Square matrix of size ``boat_size + 4``.

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside ``[0, t_end]``.
    """
    num_nodes = boat_size+4

    A_boat = np.ones((boat_size,boat_size))
    np.fill_diagonal(A_boat,0)

    # (upper bound of the interval in eighths of t_end,
    #  index of the first node of each clique present in that interval)
    schedule = [(2, (0,)),
                (3, (1,)),
                (4, (2,)),
                (5, (3,)),
                (6, (4,)),
                (8, (4, 0))]

    if t>=0:
        for eighths, first_nodes in schedule:
            if t<=eighths*t_end/8:
                A = np.zeros((num_nodes,num_nodes))
                for first in first_nodes:
                    A[first:first+boat_size, first:first+boat_size] = A_boat
                return A

    raise NotImplementedError(f"t={t} is outside [0, {t_end}]")
        
    

def get_T_forw_theseus(t, t_end=1, l=1, boat_size=3):
    """Forward transition matrix from time 0 to ``t`` for the Theseus event.

    The network changes at 2/8, 3/8, 4/8, 5/8 and 6/8 of ``t_end``. Within
    an interval, the result is the matrix at the start of that interval
    (obtained recursively) multiplied by the exponential over the time
    elapsed since then, using the Laplacian evaluated at ``t``.

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside ``[0, t_end]``.
    """
    L_f = get_L(get_A_theseus, t, t_end=t_end,  boat_size=boat_size)

    if not (t >= 0 and t <= 8*t_end/8):
        raise NotImplementedError()

    # Look for the latest change point strictly before t.
    for eighths in (6, 5, 4, 3, 2):
        start = eighths*t_end/8
        if t > start:
            T_start = get_T_forw_theseus(start,
                                         t_end=t_end, l=l,
                                         boat_size=boat_size)
            return T_start@expm(-l*(t-start)*L_f)

    # No change point before t: still in the first interval.
    return expm(-l*t*L_f)
    
def get_T_back_theseus(t, t_end=1, l=1, boat_size=3):
    """Backward transition matrix for the Theseus event.

    Same recursion as the forward version over the change points 2/8 ...
    6/8 of ``t_end``, but the Laplacian is evaluated at the reversed time
    ``t_end - t`` (plus a small offset, see below).

    Raises
    ------
    NotImplementedError
        If ``t`` lies outside ``[0, t_end]``.
    """
    # The boundary conditions of get_A_theseus were meant for the forward
    # case; a small shift of the reversed time compensates for that.
    epsilon = t_end/1000 if t>t_end/1000 else 0

    L_b = get_L(get_A_theseus, t_end-t+epsilon, t_end=t_end, boat_size=boat_size)

    if not (t >= 0 and t <= 8*t_end/8):
        raise NotImplementedError()

    # Look for the latest change point strictly before t.
    for eighths in (6, 5, 4, 3, 2):
        start = eighths*t_end/8
        if t > start:
            T_start = get_T_back_theseus(start,
                                         t_end=t_end, l=l,
                                         boat_size=boat_size)
            return T_start@expm(-l*(t-start)*L_b)

    # No change point before t: still in the first interval.
    return expm(-l*t*L_b)

    
    
#%% run complete lamba scan

#for each separate event

# 30 log-spaced rates between 1e-1 and 1e2, rounded with Round_To_n; they are
# used as dictionary keys for the results below.
lambas = [Round_To_n(l, 3) for l in np.logspace(-1,2,num=30)]
    
# Number of Louvain runs per rate and per direction.
num_repeat = 10

# Results of the scan, keyed by event type.
ev_res = dict()

# End of the time window and sizes of the synthetic networks.
t_end=1
n_per_group_nt=3
num_sub_nodes=6
n_group1=3
n_group2=3
boat_size=3

# NOTE(review): this dictionary is rebound by the loop variable `kwargs` of the
# scan loop below before being read, so it looks unused -- confirm.
kwargs= {"n_per_group_nt":n_per_group_nt, 
         "num_sub_nodes":num_sub_nodes,
         "n_group1":n_group1, 
         "n_group2":n_group2, 
         "boat_size":boat_size}



# One pass per event type. Each tuple gives: the result key, the forward and
# backward transition-matrix functions, the number of nodes of the network and
# the keyword arguments forwarded to those functions.
for event_type, get_T_forw_func, get_T_back_func, size, kwargs in\
        [ ('birth_death', get_T_forw_death, get_T_back_death, num_sub_nodes, {'num_nodes':num_sub_nodes}),
          ('contract', get_T_forw_growth, get_T_back_growth, n_group1+n_group2, {'n_group1':n_group1, 'n_group2':n_group2}),
          ('merge_split',get_T_forw_merge, get_T_back_merge, n_group1+n_group2, {'n_group1':n_group1, 'n_group2':n_group2}),
          ('continue', get_T_forw_continue, get_T_back_continue, num_sub_nodes, {'num_nodes':num_sub_nodes}),
          ('resurgence', get_T_forw_resurg, get_T_back_resurg, num_sub_nodes, {'num_nodes':num_sub_nodes}),
          ('theseus', get_T_forw_theseus, get_T_back_theseus, boat_size+4, {'boat_size':boat_size})]:
    
    
            
    print(event_type)
    
    # Per-rate containers for this event type.
    clust_res = dict()
    
    Is_forw = dict()
    Is_backw = dict()
    
    T_forw = dict()
    T_backw = dict()
    
    ev_res[event_type] = dict()
    
    
    # Uniform initial distribution over the network nodes plus one extra
    # isolated node (see T_forw_fct / T_back_fct below).
    p0 = np.ones(size+1)/(size+1)
    
    for lamba in lambas:
    
        
        print('\n+++++ lamba = ', lamba)
        
        clusters_forw = []
        stabilites_forw = []
        
        clusters_backw = []
        stabilites_backw = []
        
        
        
        #add a singleton node for easier detection of full clusters
        # Note: the `l` argument of the two wrappers is not used, the rate
        # passed on is always the current `lamba` of the enclosing loop.
        def T_forw_fct(t, t_end=t_end, l=lamba, **kwargs):
            return block_diag(get_T_forw_func(t, t_end=t_end, l=lamba, **kwargs), np.eye(1))
        
        def T_back_fct(t, t_end=t_end, l=lamba, **kwargs):
            return block_diag(get_T_back_func(t, t_end=t_end, l=lamba, **kwargs), np.eye(1))  
        
        # Integrated matrices handed to the clustering, forward and backward.
        If = compute_I(T_forw_fct, size+1, t_end=t_end, l=lamba, p0=p0, **kwargs)
        Ib = compute_I(T_back_fct, size+1, t_end=t_end, l=lamba, p0=p0, **kwargs)
        
        # Repeat the Louvain clustering num_repeat times in each direction.
        for i in range(num_repeat):
    
            print('**** PID ', os.getpid(), 'n ', i)
                
            clustering_forw = Clustering(S=If)
            clustering_backw = Clustering(S=Ib)
    
            
            clustering_forw.find_louvain_clustering()
            
            # NOTE(review): assumes the extra singleton node is the last entry
            # of cluster_list -- confirm against Clustering.
            clustering_forw.partition.cluster_list.pop() #remove singleton
        
            clusters_forw.append(clustering_forw.partition.cluster_list)
            stabilites_forw.append(clustering_forw.compute_stability())
            
            clustering_backw.find_louvain_clustering()
            
            clustering_backw.partition.cluster_list.pop() #remove singleton
            
            clusters_backw.append(clustering_backw.partition.cluster_list)
            stabilites_backw.append(clustering_backw.compute_stability())       
        
        # Mean norm_var_information over all pairs of forward partitions.
        nvarinf_forw = np.mean([norm_var_information(c1,c2) for c1,c2 in combinations(clusters_forw,2)])
        
        # Count identical partitions, using sorted tuples as a canonical form.
        clust_counter_forw = Counter([tuple(sorted([tuple(sorted(c)) for c in clust])) \
                                     for clust in clusters_forw])
        
        # Partition with the highest stability among the repetitions.
        best_cluster_forw = clusters_forw[np.argmax(stabilites_forw)]
        
        best_stab_forw = max(stabilites_forw)
        
        avg_stab_forw = np.mean(stabilites_forw)
            
        avg_nclust_forw = np.mean([len(c) for c in clusters_forw])    
        
        best_cluster_nclust_forw = len(best_cluster_forw)
    
    
        # Same summary statistics for the backward partitions.
        nvarinf_backw = np.mean([norm_var_information(c1,c2) for c1,c2 in combinations(clusters_backw,2)])
        
        clust_counter_backw = Counter([tuple(sorted([tuple(sorted(c)) for c in clust])) \
                                     for clust in clusters_backw])
        
        best_cluster_backw = clusters_backw[np.argmax(stabilites_backw)]
        
        best_stab_backw = max(stabilites_backw)
        
        avg_stab_backw = np.mean(stabilites_backw)
            
        avg_nclust_backw = np.mean([len(c) for c in clusters_backw])    
        
        best_cluster_nclust_backw = len(best_cluster_backw)
    
        # Collect everything computed for this rate.
        res = dict()
        
      
        res['clust_counter_forw'] = clust_counter_forw
        res['stabilites_forw'] = stabilites_forw
        res['nvarinf_forw'] = nvarinf_forw
        res['avg_stab_forw'] = avg_stab_forw
        res['avg_nclust_forw'] = avg_nclust_forw
        res['best_cluster_forw'] = best_cluster_forw
        res['best_stab_forw'] = best_stab_forw
        res['best_cluster_nclust_forw'] = best_cluster_nclust_forw
        res['clusters_forw'] = clusters_forw
        
        res['clust_counter_backw'] = clust_counter_backw
        res['stabilites_backw'] = stabilites_backw
        res['nvarinf_backw'] = nvarinf_backw
        res['avg_stab_backw'] = avg_stab_backw
        res['avg_nclust_backw'] = avg_nclust_backw
        res['best_cluster_backw'] = best_cluster_backw
        res['best_stab_backw'] = best_stab_backw
        res['best_cluster_nclust_backw'] = best_cluster_nclust_backw
        res['clusters_backw'] = clusters_backw
        
        clust_res[lamba] = res
        
        Is_forw[lamba] = If
        Is_backw[lamba] = Ib
                
        # Transition matrices over the whole window, without the row and
        # column of the extra singleton node.
        T_forw[lamba] = T_forw_fct(t_end, t_end=t_end, l=lamba, **kwargs)[np.ix_(range(size),range(size))]
        T_backw[lamba] = T_back_fct(t_end, t_end=t_end, l=lamba, **kwargs)[np.ix_(range(size),range(size))]
    
    varinf_lamb_backw = []
    varinf_lamb_forw = []
    
    # Mean norm_var_information between the partitions found at consecutive
    # rates, for each direction.
    for k in range(len(lambas)-1):
        varinf_lamb_backw.append(np.mean([norm_var_information(c1,c2) for c1,c2 in \
                                    product(clust_res[lambas[k]]['clusters_backw'],
                                            clust_res[lambas[k+1]]['clusters_backw'])]))
        varinf_lamb_forw.append(np.mean([norm_var_information(c1,c2) for c1,c2 in \
                                    product(clust_res[lambas[k]]['clusters_forw'],
                                            clust_res[lambas[k+1]]['clusters_forw'])]))
        
    ev_res[event_type]['clust_res'] = clust_res
    ev_res[event_type]['Is_forw'] = Is_forw
    ev_res[event_type]['Is_backw'] = Is_backw
    ev_res[event_type]['varinf_lamb_backw'] = varinf_lamb_backw
    ev_res[event_type]['varinf_lamb_forw'] = varinf_lamb_forw
    ev_res[event_type]['T_forw'] = T_forw
    ev_res[event_type]['T_backw'] = T_backw
    
#%% save results


# Write the results of the scan to datadir.
with open(os.path.join(datadir, 'ev_res.pickle'), 'wb') as fopen:
    pickle.dump(ev_res, fopen)
    
    
#%% load
# Read the results back, so the plotting cells below can be run without
# repeating the scan. Only load pickle files that this script produced:
# unpickling untrusted data can execute arbitrary code.
with open(os.path.join(datadir, 'ev_res.pickle'), 'rb') as fopen:
    ev_res = pickle.load(fopen)    


#%% plot for paper figure



# birth death
# Selection of the event and of the rate index to plot. The assignments below
# overwrite each other: when this cell is run as a whole, only the last values
# (event_type = 'theseus', ind_lamba = 25) take effect. Comment or reorder the
# lines to plot another case.
event_type = 'birth_death'
# ind_lamba =13 # all singletons
ind_lamba =16 # full to singletons

# contract
event_type = 'contract'
ind_lamba =13 # [0],[1],[2], [3,4,5] -> singletons
# ind_lamba =14 # [0],[1],[2], [3,4,5] -> [0],[1],[2], [3,4,5]
# ind_lamba =15 # [0],[1],[2], [3,4,5] -> full

#merge split
event_type = 'merge_split'
ind_lamba =13 # two groups -> singletons
ind_lamba =15 # two groups -> one group
ind_lamba =14 # two groups -> two groups
# ind_lamba =5 # singletons

# # continue
event_type = 'continue'
ind_lamba =10 # singletons
ind_lamba =20 # one comm

# # resurgence
event_type = 'resurgence'
# ind_lamba =10 # singletons
ind_lamba = 20 # one group

# # theseus 
event_type = 'theseus'
ind_lamba = 16 # [0,1,2] -> [0,1,2],[3],[4,5,6]
ind_lamba = 25 # [0,1,2,4] -> [0,1,2],[3],[4,5,6]
# ind_lamba = 6 # singletons

# Rates available for the selected event, in increasing order.
lambas = np.sort(np.array(list(ev_res[event_type]['clust_res'].keys())))
lamba = lambas[ind_lamba]
print(lamba)

# Figure with 8 columns: one narrow panel per partition (forward, backward)
# and the remaining columns for the transition graph.
fig = plt.figure(figsize=(8,4))
gs = fig.add_gridspec(1,8,
                      hspace=1,
#                      wspace=0.5, 
#                      bottom=0.3,
#                      left=0.02,
#                      right=0.95
                      )
# forw and back partitions
# Each cluster is drawn as a vertical line through its node indices; a None
# entry breaks the line between two clusters.
forw_lines = []

for c in sorted(ev_res[event_type]['clust_res'][lamba]['best_cluster_forw']):
    forw_lines += list(c) + [None]

backw_lines = []

# Set to True when a backward cluster has a gap larger than 1 between
# consecutive node indices.
backshifted = False
for c in sorted(ev_res[event_type]['clust_res'][lamba]['best_cluster_backw']):
    backw_lines += list(c) + [None]
    if (np.diff(list(c)) > 1).any():
        backshifted = True


xs_forw = [0]*len(forw_lines)
xs_backw = [0]*len(backw_lines)
if backshifted:
    # Give every backward cluster its own x position (steps of 0.5).
    xs_backw = []
    for x,c in enumerate(sorted(ev_res[event_type]['clust_res'][lamba]['best_cluster_backw'])):
        xs_backw += [x*0.5]*len(c) + [None]

ax1 = fig.add_subplot(gs[0,0],)
_ = ax1.plot(xs_forw,forw_lines,'o-k',markersize=5,lw=2)
_ = ax1.axis('off')
ax1.set_xlim([-0.5,1])

ax4 = fig.add_subplot(gs[0,1],sharey=ax1)
_ = ax4.plot(xs_backw,backw_lines,'o-k',markersize=5,lw=2)
_ = ax4.axis('off')
ax4.set_xlim([-0.5,1])

_ = ax1.text(-0.25,0.5,'Forward partition', 
         rotation=90, 
         transform=ax1.transAxes,  horizontalalignment='center',
      verticalalignment='center')

_ = ax4.text(1.25,0.5,'Backward partition', 
         rotation=270, 
         transform=ax4.transAxes,  horizontalalignment='center',
      verticalalignment='center')
    
# Reverse the y axis so that node 0 is drawn at the top (ax4 shares it).
ylims = ax1.get_ylim()
ax1.set_ylim(ylims[::-1])

# proba graph
# Forward transition matrix over the whole window for the selected rate.
Tf = ev_res[event_type]['T_forw'][lamba]

import networkx as nx
Tgraph = nx.DiGraph()

#forw clust to back clust
# Nodes 'f<s>' are the forward clusters and nodes 'b<t>' the backward clusters.
# Labels show the node indices of each cluster, shifted to start at 1.
node_labels = {}
for s, cf in enumerate(sorted(ev_res[event_type]['clust_res'][lamba]['best_cluster_forw'])):
    print(s)
    
    Tgraph.add_node(f'f{s}', size=len(cf))
    node_labels[f'f{s}']=f'{[n+1 for n in list(cf)]}'
    
    for t, cb in enumerate(sorted(ev_res[event_type]['clust_res'][lamba]['best_cluster_backw'])):
        
        Tgraph.add_node(f'b{t}', size=len(cb))
        
        node_labels[f'b{t}']=f'{[n+1 for n in list(cb)]}'
    
        # NOTE(review): Tsub is not used afterwards; the edge weight below
        # recomputes the same sub-matrix.
        Tsub = Tf[np.ix_(list(cf),list(cb))]
        
        # Edge weight: sum of the transition probabilities from the forward
        # cluster to the backward cluster, divided by the forward cluster size.
        Tgraph.add_edge(f'f{s}', f'b{t}', weight=Tf[np.ix_(list(cf),list(cb))].sum()/len(cf))
    


# Scaling factors from cluster size to marker size and from edge weight to
# line width.
nodesizefact=20
edgewidthfact=5

fnodes = [n for n in Tgraph.nodes() if n.startswith('f')]
bnodes = [n for n in Tgraph.nodes() if n.startswith('b')]

# Two-column layout with the forward clusters on one side.
pos = nx.layout.bipartite_layout(Tgraph, nodes=fnodes)

# sort y pos
# Reassign the positions of each side so that nodes sorted by name are placed
# from top to bottom.
fposs = [pos[n] for n in fnodes]
bposs = [pos[n] for n in bnodes]

pos ={n:p for n,p in zip(sorted(fnodes), sorted(fposs, key=lambda x:x[1], reverse=True))}
pos.update({n:p for n,p in zip(sorted(bnodes), sorted(bposs, key=lambda x:x[1], reverse=True))})

# Labels are shifted horizontally away from the nodes: to the left for the
# forward clusters, to the right for the backward clusters.
shift=0.15
pos_labels = {}
for n,p in pos.items():
    if n.startswith('f'):
        pos_labels[n] = p + np.array([-shift,0])
    else:
        pos_labels[n] = p + np.array([+shift,0])


ax9 = fig.add_subplot(gs[0,2:])

# Only edges with a weight above 0.01 are drawn; edgelist and width use the
# same filter so that they stay aligned.
nx.draw_networkx_edges(Tgraph, pos, alpha=0.9,
                       edgelist=[(s,t) for s,t,d in Tgraph.edges(data=True) if d['weight']>0.01],
                       width=[d['weight']*edgewidthfact for _,_,d in Tgraph.edges(data=True) if d['weight']>0.01], 
                       arrowsize=10,
                       connectionstyle='arc3, rad = -0.05',
                       edge_color="k",
                       node_size=1000,
                       ax=ax9)
nx.draw_networkx_nodes(Tgraph, pos, 
                       node_size=[d['size']*nodesizefact for _,d in Tgraph.nodes(data=True)],
                       node_color="#210070", alpha=0.9,
                       ax=ax9)
# label_options = {"ec": "k", "fc": "white", "alpha": 0.7}
# Labels are drawn in two calls because the two sides use different
# horizontal alignments.
nx.draw_networkx_labels(Tgraph, pos_labels, 
                        labels={n:lab for n,lab in node_labels.items() if n.startswith('f')},
                                                    font_size=10,
                        horizontalalignment='right',
                        ax=ax9)
nx.draw_networkx_labels(Tgraph, pos_labels, labels={n:lab for n,lab in node_labels.items() if n.startswith('b')},
                        font_size=10,
                        horizontalalignment='left',
                        ax=ax9)


ax9.margins(x=0.45)
ax9.patch.set_visible(False)
ax9.axis('off')

# The title reports 1/lamba as tau_w.
_ = fig.suptitle(event_type + r', $\tau_w$=' + '{:.3f}'.format(1/lamba))

#%%
# The paper-figure folder is not created anywhere else in this script, so make
# sure it exists before writing into it.
paper_figdir = '../figures/paper_figures/'
os.makedirs(paper_figdir, exist_ok=True)

# File name: event type followed by tau_w = 1/lamba with three decimals, the
# decimal point being spelled 'dot' (only in that part, not in '.pdf').
tau_w_tag = '_tau_w{:.3f}'.format(1/lamba).replace('.','dot')
plt.savefig(os.path.join(paper_figdir,
                         'partition_community_event_' + event_type + tau_w_tag + '.pdf'))
    

    
